package com.simafei.flow.core;

import cn.hutool.core.collection.CollectionUtil;
import com.simafei.flow.core.common.ExecStatus;
import com.simafei.flow.core.common.FlowStatus;
import com.simafei.flow.core.common.NodeType;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
 * A node in a flow graph.
 *
 * <p>Nodes are connected through incoming and outgoing {@link Edge}s. Executing a node runs
 * {@link #doExecute} asynchronously, records the resulting {@link NodeResult} in the
 * {@link ExecutionContext}, and then evaluates every outgoing edge; each downstream node whose
 * edge lets at least one parameter row through is executed in turn.
 *
 * @author fengpengju
 */
@ToString(exclude = {"outEdges", "inEdges"})
@Slf4j
public class Node {

    /**
     * Node id.
     */
    @Setter
    @Getter
    private String id;

    /**
     * Node name; also used when logging execution errors.
     */
    @Setter
    @Getter
    private String name;

    /**
     * Whether this is a blocking node. A blocking node only executes after every incoming edge has
     * produced a result, and then runs once with the merged filtered parameters of those incoming
     * edges.
     */
    @Setter
    @Getter
    private boolean blocked;

    /**
     * Pass-through data attached to the node (not read by this class).
     */
    @Setter
    @Getter
    private Map<String, Object> payload;

    /**
     * Node type; copied into every {@link NodeResult} this node produces.
     */
    @Getter
    private final NodeType nodeType;

    /**
     * Outgoing edges.
     */
    @Getter
    private final List<Edge> outEdges;

    /**
     * Incoming edges.
     */
    @Getter
    private final List<Edge> inEdges;

    /**
     * Whether this (blocking) node has already been executed. Ensures that when several upstream
     * branches concurrently observe that all incoming edges are done, only one of them runs the node.
     *
     * <p>NOTE(review): this flag belongs to the node instance, not to the {@link ExecutionContext};
     * if node instances are reused across executions, a blocking node would only run in the first
     * one — confirm that nodes are built per execution.
     */
    private final AtomicBoolean executed = new AtomicBoolean(false);


    /**
     * Creates a node of the given type with no edges.
     *
     * @param nodeType the node type
     */
    public Node(NodeType nodeType) {
        this.nodeType = nodeType;
        this.outEdges = new ArrayList<>();
        this.inEdges = new ArrayList<>();
    }

    /**
     * Adds an outgoing edge.
     *
     * @param edge edge leaving this node
     */
    public final void addOutEdge(Edge edge) {
        outEdges.add(edge);
    }

    /**
     * Adds an incoming edge.
     *
     * @param edge edge entering this node
     */
    public final void addInEdge(Edge edge) {
        inEdges.add(edge);
    }


    /**
     * Executes this node without a parent id.
     *
     * @param context the execution context
     * @param params  input parameter rows
     * @return a future that completes when this node and the downstream nodes it triggered are done
     */
    public CompletableFuture<Void> execute(ExecutionContext context, List<Map<String, Object>> params) {
        return execute(context, params, null);
    }

    /**
     * Executes this node, honouring the {@link #blocked} flag.
     *
     * <p>A non-blocking node runs immediately with {@code params}. A blocking node ignores
     * {@code params}: it runs only once every incoming edge has a result in the context, using the
     * merged filtered parameters of its own incoming edges, and only once per node instance. In every
     * other case an already-completed future is returned.
     *
     * @param context  the execution context
     * @param params   input parameter rows (ignored for blocking nodes)
     * @param parentId id of the edge/parent that triggered this execution; may be {@code null}
     * @return a future that completes when this node and the downstream nodes it triggered are done
     */
    public final CompletableFuture<Void> execute(ExecutionContext context, List<Map<String, Object>> params, String parentId) {
        if (!blocked) {
            return executeAsync(context, params, parentId);
        }
        // The atomic flag guarantees a single run even if several threads see all in-edges done at once.
        if (isAllInEdgesExecuted(context) && !executed.getAndSet(true)) {
            return executeAsync(context, mergeInEdgeParams(context), parentId);
        }
        return CompletableFuture.completedFuture(null);
    }

    /**
     * Executes the node asynchronously and then propagates to downstream nodes.
     *
     * <p>The listener (if any) is notified before execution; the result is recorded in the context
     * and the listener is notified after execution. Status updates and downstream fan-out happen in
     * {@link #propagate}.
     *
     * @param context  the execution context
     * @param params   input parameter rows
     * @param parentId id of the edge/parent that triggered this execution; may be {@code null}
     * @return a future that completes when this node and the downstream nodes it triggered are done
     */
    public final CompletableFuture<Void> executeAsync(ExecutionContext context, List<Map<String, Object>> params, String parentId) {
        return CompletableFuture.supplyAsync(() -> {
            Optional.ofNullable(context.getExecutionListener())
                    .ifPresent(listener -> listener.beforeNodeExecute(context, this));
            NodeResult nodeResult = executeWithHandling(context, params, parentId);
            recordResult(context, nodeResult);
            return nodeResult;
        }).thenComposeAsync(result -> propagate(context, result));
    }

    /**
     * Updates the context status from a node result and, on a finished success, executes the
     * downstream nodes along every outgoing edge that lets at least one row through.
     */
    private CompletableFuture<Void> propagate(ExecutionContext context, NodeResult result) {
        if (!result.isSuccess()) {
            context.setExecStatus(result.getExecStatus());
            context.setFlowStatus(FlowStatus.REJECT);
            return CompletableFuture.completedFuture(null);
        }
        if (result.getExecStatus() == ExecStatus.WAITING) {
            // If an end node has already finished the flow, do not move it back to WAITING
            // (could later be made configurable as a strategy).
            if (context.getExecStatus() != ExecStatus.FINISHED) {
                context.setExecStatus(ExecStatus.WAITING);
            }
            return CompletableFuture.completedFuture(null);
        }
        if (CollectionUtil.isEmpty(outEdges)) {
            return CompletableFuture.completedFuture(null);
        }

        // Merge results with inputs; this becomes the input for edge conditions and downstream nodes.
        List<Map<String, Object>> nextInputParams = mergeToNextInputParams(result);

        List<CompletableFuture<Void>> futures = outEdges.stream()
                .filter(edge -> Objects.nonNull(edge.getTo()))
                // Evaluate the edge: not every input row necessarily passes it.
                .map(edge -> executeEdge(context, edge, nextInputParams))
                // Run a downstream node only if at least one row passed its edge.
                .filter(EdgeResultWrapper::isPass)
                .map(wrapper -> wrapper.executeTo(context))
                .toList();
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
    }

    /**
     * Returns whether every incoming edge of this node has a result recorded in the context.
     * Returns {@code false} when the context has no edge results at all.
     */
    private boolean isAllInEdgesExecuted(ExecutionContext context) {
        var edgeResults = context.getEdgeResults();
        if (CollectionUtil.isEmpty(edgeResults)) {
            return false;
        }
        Set<String> executedEdgeIds = edgeResults.stream()
                .map(EdgeResult::getEdgeId)
                .collect(Collectors.toSet());
        // All incoming edges are done when the executed edge ids cover every in-edge id.
        return executedEdgeIds.containsAll(inEdgeIds());
    }

    /**
     * Collects the filtered parameters of this node's own incoming edges, in the order their results
     * appear in the context. Results of edges that do not end at this node are excluded.
     */
    private List<Map<String, Object>> mergeInEdgeParams(ExecutionContext context) {
        Set<String> inEdgeIds = inEdgeIds();
        return context.getEdgeResults().stream()
                .filter(edgeResult -> inEdgeIds.contains(edgeResult.getEdgeId()))
                .map(EdgeResult::getFilteredParams)
                .filter(Objects::nonNull)
                .flatMap(Collection::stream)
                .toList();
    }

    /**
     * Ids of this node's incoming edges.
     */
    private Set<String> inEdgeIds() {
        return inEdges.stream().map(Edge::getId).collect(Collectors.toSet());
    }

    /**
     * Builds the parameter rows passed to outgoing edges and downstream nodes.
     *
     * <ul>
     *   <li>No execution results: each input row is forwarded, marked as having no result.</li>
     *   <li>More input rows than execution results (the node aggregated its input): only the
     *       execution-result rows are forwarded, marked as having a result.</li>
     *   <li>Otherwise each execution result's rows are merged on top of that result's input row, so
     *       later nodes still see earlier inputs; a result with no rows forwards its input row
     *       marked as having no result.</li>
     * </ul>
     *
     * <p>{@code null} result lists, input lists and input rows are treated as empty.
     */
    private List<Map<String, Object>> mergeToNextInputParams(NodeResult result) {
        List<Map<String, Object>> nextNodeInputParams = new ArrayList<>();

        List<ExecutionResult> execResults = result.getResults() == null ? List.of() : result.getResults();
        List<Map<String, Object>> inputParams = result.getInputParams() == null ? List.of() : result.getInputParams();

        if (execResults.isEmpty()) {
            // No execution results (neither variables nor results): forward the inputs as-is.
            for (Map<String, Object> inputParam : inputParams) {
                nextNodeInputParams.add(nextStepParam(inputParam, false));
            }
        } else if (inputParams.size() > execResults.size()) {
            // The node aggregated its input, so earlier input rows no longer line up; forward results only.
            for (ExecutionResult execResult : execResults) {
                if (CollectionUtil.isEmpty(execResult.getResults())) {
                    continue;
                }
                for (Map<String, Object> execResultResult : execResult.getResults()) {
                    nextNodeInputParams.add(nextStepParam(execResultResult, true));
                }
            }
        } else {
            for (ExecutionResult execResult : execResults) {
                if (CollectionUtil.isEmpty(execResult.getResults())) {
                    // Variables were present but nothing was produced: forward the input row.
                    nextNodeInputParams.add(nextStepParam(execResult.getInputParams(), false));
                    continue;
                }
                // Merge each result row on top of its input row so later nodes can use earlier inputs.
                for (Map<String, Object> execResultResult : execResult.getResults()) {
                    Map<String, Object> newResults = nextStepParam(execResult.getInputParams(), true);
                    newResults.putAll(execResultResult);
                    nextNodeInputParams.add(newResults);
                }
            }
        }
        return nextNodeInputParams;
    }

    /**
     * Copies a parameter row (or starts an empty one when it is {@code null}) and tags it with the
     * {@code FlowConstants.HAS_RESULT} flag.
     */
    private Map<String, Object> nextStepParam(Map<String, Object> inputParam, boolean hasResult) {
        Map<String, Object> newResults = inputParam == null ? new HashMap<>() : new HashMap<>(inputParam);
        newResults.put(FlowConstants.HAS_RESULT, hasResult ? FlowConstants.YES : FlowConstants.NO);
        return newResults;
    }

    /**
     * Records the node result in the context and notifies the execution listener, if any.
     */
    private void recordResult(ExecutionContext context, NodeResult result) {
        context.collectResult(result);
        Optional.ofNullable(context.getExecutionListener())
                .ifPresent(listener -> listener.afterNodeExecute(context, this, result));
    }

    /**
     * Runs {@link #doExecute} and wraps the outcome in a {@link NodeResult}.
     *
     * <p>On success the result is FINISHED, carries the input params and keeps the context's flow
     * status. Any {@link Throwable} is logged and turned into a FAILED/REJECT result carrying the
     * cause, so the async chain never completes exceptionally because of node logic.
     *
     * @param context  the execution context
     * @param params   input parameter rows
     * @param parentId id of the edge/parent that triggered this execution; may be {@code null}
     * @return the node result, never {@code null}
     */
    protected NodeResult executeWithHandling(ExecutionContext context, List<Map<String, Object>> params, String parentId) {
        long startTime = System.currentTimeMillis();

        NodeResult result;
        try {
            result = resultBuilder()
                    .success(true)
                    .inputParams(params)
                    .startTime(startTime)
                    .results(doExecute(context, params))
                    .execStatus(ExecStatus.FINISHED)
                    .flowStatus(context.getFlowStatus())
                    .endTime(System.currentTimeMillis())
                    .parentId(parentId)
                    .build();

        } catch (Throwable e) {
            log.error("Node execute error, node: {}", name, e);
            result = resultBuilder()
                    .startTime(startTime)
                    .success(false)
                    .cause(e)
                    .execStatus(ExecStatus.FAILED)
                    .flowStatus(FlowStatus.REJECT)
                    .endTime(System.currentTimeMillis())
                    .parentId(parentId)
                    .build();
        }
        return result;
    }

    /**
     * Node-specific logic; subclasses override this. The default implementation produces no results.
     *
     * @param context the execution context
     * @param params  input parameter rows
     * @return execution results; the base implementation returns an empty list
     */
    protected List<ExecutionResult> doExecute(ExecutionContext context, List<Map<String, Object>> params) {
        return List.of();
    }

    /**
     * Starts a {@link NodeResult} builder pre-filled with this node's name, id and type.
     */
    private NodeResult.NodeResultBuilder resultBuilder() {
        return NodeResult.builder()
                .name(name)
                .nodeId(id)
                .nodeType(nodeType);
    }

    /**
     * Evaluates an outgoing edge against the given rows and pairs the edge with its result.
     */
    private EdgeResultWrapper executeEdge(ExecutionContext context, Edge edge, List<Map<String, Object>> params) {
        EdgeResultWrapper wrapper = new EdgeResultWrapper();
        wrapper.edge = edge;
        wrapper.edgeResult = edge.execute(context, params);
        return wrapper;
    }

    /**
     * An edge together with the result of evaluating it.
     */
    private static class EdgeResultWrapper {
        private EdgeResult edgeResult;

        @Getter
        private Edge edge;

        boolean isPass() {
            return edgeResult.isPass();
        }

        CompletableFuture<Void> executeTo(ExecutionContext context) {
            // Execute the target node with only the rows that satisfied this edge's condition.
            return edge.getTo().execute(context, edgeResult.getFilteredParams(), edge.getId());
        }
    }
}
